/* Fallback definitions: each macro below is defined only when the name is
 * not already defined (for example by a header included before this file). */
#ifndef skynet_malloc
#define skynet_malloc malloc
#endif
#ifndef skynet_free
#define skynet_free free
#endif

/* Socket message type codes, defined here only when the skynet_socket_h
 * guard macro is not already set.
 * NOTE(review): values presumably mirror skynet's socket header — confirm. */
#ifndef skynet_socket_h
#define SKYNET_SOCKET_TYPE_DATA 1
#define SKYNET_SOCKET_TYPE_CONNECT 2
#define SKYNET_SOCKET_TYPE_CLOSE 3
#define SKYNET_SOCKET_TYPE_ACCEPT 4
#define SKYNET_SOCKET_TYPE_ERROR 5
#define SKYNET_SOCKET_TYPE_UDP 6
#endif
/* end of fallback definitions */

#include <lua.h>
#include <lauxlib.h>

#include <assert.h>
#include <limits.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

/* Slots in a new queue's ring buffer; expand_queue also grows it by this much. */
#define QUEUESIZE 1024
/* Number of buckets in the per-fd partial-package hash table. */
#define HASHSIZE 4096
/* Payloads up to this many bytes are assembled in a stack buffer by
 * pack_string / pack_padding; larger ones use a temporary userdata. */
#define SMALLSTRING 2048

/* Upvalue indices of the type-name strings attached to filter. The order
 * must match the order the strings are pushed in luaopen_netpackxa. */
#define TYPE_DATA 1
#define TYPE_MORE 2
#define TYPE_ERROR 3
#define TYPE_OPEN 4
#define TYPE_CLOSE 5

/*
        Each package is a uint16 length header followed by the data. The
        header is serialized big-endian and gives the number of bytes in the data.
 */

/* A complete package waiting in the queue. */
struct netpack {
  int id;       /* socket id (fd) the package arrived on */
  int size;     /* payload length in bytes */
  void *buffer; /* heap-allocated payload; freed by clear, or handed to Lua by pop */
};

/* A package still being reassembled for one fd; nodes are chained per hash bucket. */
struct uncomplete {
  struct netpack pack;     /* pack.size = announced length, pack.buffer = bytes so far */
  struct uncomplete *next; /* next node in the same hash bucket */
  int read;   /* payload bytes received so far, or -1 if only the first header byte arrived */
  int header; /* that first (high) header byte; meaningful only when read == -1 */
};

/* Lua userdata holding a ring of complete packages plus the partial ones. */
struct queue {
  int cap;  /* number of usable slots in queue[] */
  int head; /* index of the oldest package; head == tail means empty */
  int tail; /* index of the next free slot */
  struct uncomplete *hash[HASHSIZE]; /* partial packages bucketed by hash_fd */
  /* Must stay the last member: expand_queue allocates the struct with extra
   * trailing slots so that cap can exceed QUEUESIZE.
   * NOTE(review): indexing past QUEUESIZE is technically out of bounds in
   * ISO C; a flexible array member would make this well-defined. */
  struct netpack queue[QUEUESIZE];
};

/*
 * Free every node of an uncomplete chain, including each node's partially
 * assembled payload buffer (previously leaked). Nodes still waiting for the
 * second header byte (read == -1) have a NULL buffer because save_uncomplete
 * zeroes them; free(NULL) is a no-op.
 * NOTE(review): assumes any skynet_free override also accepts NULL.
 */
static void clear_list(struct uncomplete *uc) {
  while (uc) {
    struct uncomplete *next = uc->next;
    skynet_free(uc->pack.buffer);
    skynet_free(uc);
    uc = next;
  }
}

/*
 * clear(queue)
 * Release everything held by the queue at stack slot 1: all partial
 * packages and all complete packages not yet popped. The queue is left
 * empty and reusable. Does nothing if slot 1 is not a userdata.
 */
static int lclear(lua_State *L) {
  struct queue *q = lua_touserdata(L, 1);
  if (q == NULL)
    return 0;

  int bucket;
  for (bucket = 0; bucket < HASHSIZE; bucket++) {
    clear_list(q->hash[bucket]);
    q->hash[bucket] = NULL;
  }

  /* Walk the ring from head to tail, unwrapping tail when it has wrapped. */
  int end = q->tail;
  if (q->head > end)
    end += q->cap;
  int pos;
  for (pos = q->head; pos < end; pos++)
    skynet_free(q->queue[pos % q->cap].buffer);

  q->head = 0;
  q->tail = 0;
  return 0;
}

/*
 * Map a socket id to a bucket index in [0, HASHSIZE).
 *
 * The three terms are summed as uint32_t. Adding them as int overflowed for
 * ids near INT_MAX, which is undefined behavior. Unsigned addition wraps
 * modulo 2^32, so the result is unchanged wherever the signed sum did not
 * overflow.
 */
static inline int hash_fd(int fd) {
  uint32_t a = (uint32_t)(fd >> 24);
  uint32_t b = (uint32_t)(fd >> 12);
  uint32_t c = (uint32_t)fd;
  return (int)((a + b + c) % HASHSIZE);
}

/*
 * Find the partial package for fd and unlink it from its hash bucket.
 * Returns the detached node, which callers either re-insert or free, or
 * NULL when q is NULL or no node for fd exists.
 */
static struct uncomplete *find_uncomplete(struct queue *q, int fd) {
  if (q == NULL)
    return NULL;
  /* Walk the chain through the link that points at each node, so removing
   * the head and removing an interior node are the same operation. */
  struct uncomplete **link = &q->hash[hash_fd(fd)];
  while (*link != NULL) {
    struct uncomplete *node = *link;
    if (node->pack.id == fd) {
      *link = node->next;
      return node;
    }
    link = &node->next;
  }
  return NULL;
}

/*
 * Return the queue at stack slot 1, creating it on first use.
 * A new queue is a Lua full userdata that replaces slot 1, so the queue
 * returned to Lua from this call onward is the new one.
 */
static struct queue *get_queue(lua_State *L) {
  struct queue *q = lua_touserdata(L, 1);
  if (q != NULL)
    return q;

  q = lua_newuserdata(L, sizeof(struct queue));
  q->cap = QUEUESIZE;
  q->head = q->tail = 0;
  int k;
  for (k = 0; k < HASHSIZE; k++)
    q->hash[k] = NULL;
  lua_replace(L, 1);
  return q;
}

/*
 * Grow a full ring buffer by QUEUESIZE slots.
 *
 * Called right after a push made head == tail, i.e. all q->cap slots hold
 * packages. struct queue already embeds QUEUESIZE slots, so allocating
 * q->cap extra slots yields cap + QUEUESIZE in total. Packages are copied in
 * FIFO order starting at index 0, and the partial-package table moves to the
 * new queue. The old queue is emptied, presumably so a stale reference to it
 * cannot free the same buffers again. The new queue replaces stack slot 1.
 */
static void expand_queue(lua_State *L, struct queue *q) {
  size_t bytes = sizeof(struct queue) + q->cap * sizeof(struct netpack);
  struct queue *nq = lua_newuserdata(L, bytes);
  nq->cap = q->cap + QUEUESIZE;
  nq->head = 0;
  nq->tail = q->cap;

  memcpy(nq->hash, q->hash, sizeof(nq->hash));
  memset(q->hash, 0, sizeof(q->hash));

  int n;
  for (n = 0; n < q->cap; n++)
    nq->queue[n] = q->queue[(q->head + n) % q->cap];

  q->head = 0;
  q->tail = 0;
  lua_replace(L, 1);
}

/*
 * Append a complete package to the queue at stack slot 1 (created if needed).
 * If clone is non-zero the bytes are copied into a fresh heap buffer;
 * otherwise the queue takes ownership of buffer itself. The ring grows as
 * soon as it fills, so head == tail always means "empty".
 */
static void push_data(lua_State *L, int fd, void *buffer, int size, int clone) {
  void *payload = buffer;
  if (clone) {
    payload = skynet_malloc(size);
    memcpy(payload, buffer, size);
  }

  struct queue *q = get_queue(L);
  struct netpack *slot = &q->queue[q->tail];
  slot->id = fd;
  slot->buffer = payload;
  slot->size = size;

  q->tail++;
  if (q->tail >= q->cap)
    q->tail -= q->cap;
  if (q->tail == q->head)
    expand_queue(L, q);
}

/*
 * Allocate a zero-filled partial-package node for fd and push it onto the
 * front of its hash bucket in the queue at stack slot 1 (created if needed).
 * The caller fills in read / header / pack.size / pack.buffer.
 */
static struct uncomplete *save_uncomplete(lua_State *L, int fd) {
  struct queue *q = get_queue(L);
  struct uncomplete *node = skynet_malloc(sizeof *node);
  memset(node, 0, sizeof *node);
  node->pack.id = fd;

  int bucket = hash_fd(fd);
  node->next = q->hash[bucket];
  q->hash[bucket] = node;
  return node;
}

/* Decode the 2-byte big-endian length header at buffer; result is 0..65535. */
static inline int read_size(uint8_t *buffer) {
  return (buffer[0] << 8) + buffer[1];
}

/*
 * Split the rest of a socket chunk into packages and queue them.
 *
 * Each complete package is copied into the queue. A trailing fragment (a
 * lone header byte, or a header plus fewer payload bytes than announced) is
 * stored as an uncomplete node so the next chunk for fd can finish it.
 * NOTE(review): the visible callers always pass size >= 1; size == 0 would
 * read the header past the end of buffer.
 */
static void push_more(lua_State *L, int fd, uint8_t *buffer, int size) {
  do {
    if (size == 1) {
      /* Only the high byte of the length header has arrived. */
      struct uncomplete *uc = save_uncomplete(L, fd);
      uc->read = -1;
      uc->header = *buffer;
      return;
    }

    int pack_size = read_size(buffer);
    uint8_t *payload = buffer + 2;
    int avail = size - 2;

    if (avail < pack_size) {
      /* Header complete, payload truncated: keep what we have. */
      struct uncomplete *uc = save_uncomplete(L, fd);
      uc->read = avail;
      uc->pack.size = pack_size;
      uc->pack.buffer = skynet_malloc(pack_size);
      memcpy(uc->pack.buffer, payload, avail);
      return;
    }

    push_data(L, fd, payload, pack_size, 1);
    buffer = payload + pack_size;
    size = avail - pack_size;
  } while (size > 0);
}

/*
 * Feed one chunk of socket data for fd into the reassembly state.
 *
 * Stack slot 1 holds the queue (or nil); helpers may replace it with a new
 * or grown queue. Returns the number of Lua results, starting at slot 1:
 *   1  queue                          -- data buffered, nothing complete yet
 *   5  queue, "data", fd, ptr, size   -- exactly one complete package; ptr is
 *                                        a heap buffer handed to the caller
 *   2  queue, "more"                  -- packages were queued; fetch with pop
 */
static int filter_data_(lua_State *L, int fd, uint8_t *buffer, int size) {
  if (size <= 0) {
    /* Nothing to consume; reading a header byte here would overrun buffer. */
    return 1;
  }
  struct queue *q = lua_touserdata(L, 1);
  struct uncomplete *uc = find_uncomplete(q, fd);
  if (uc) {
    /* Continue the package left over from an earlier chunk. */
    if (uc->read < 0) {
      /* Only the high header byte was seen; this byte completes the size. */
      assert(uc->read == -1);
      int pack_size = *buffer;
      pack_size |= uc->header << 8;
      ++buffer;
      --size;
      uc->pack.size = pack_size;
      uc->pack.buffer = skynet_malloc(pack_size);
      uc->read = 0;
    }
    /* Byte view of the payload: pointer arithmetic on void * is not ISO C. */
    uint8_t *dst = uc->pack.buffer;
    int need = uc->pack.size - uc->read;
    if (size < need) {
      /* Still incomplete: append and put the node back in its bucket. */
      memcpy(dst + uc->read, buffer, size);
      uc->read += size;
      int h = hash_fd(fd);
      uc->next = q->hash[h];
      q->hash[h] = uc;
      return 1;
    }
    memcpy(dst + uc->read, buffer, need);
    buffer += need;
    size -= need;
    if (size == 0) {
      /* The chunk ended exactly on the package boundary. */
      lua_pushvalue(L, lua_upvalueindex(TYPE_DATA));
      lua_pushinteger(L, fd);
      lua_pushlightuserdata(L, uc->pack.buffer);
      lua_pushinteger(L, uc->pack.size);
      skynet_free(uc);
      return 5;
    }
    /* More data follows: queue the finished package (its buffer now belongs
     * to the queue, so it is not cloned) and split the remainder. */
    push_data(L, fd, uc->pack.buffer, uc->pack.size, 0);
    skynet_free(uc);
    push_more(L, fd, buffer, size);
    lua_pushvalue(L, lua_upvalueindex(TYPE_MORE));
    return 2;
  } else {
    if (size == 1) {
      /* Only the high header byte is available. */
      struct uncomplete *uc = save_uncomplete(L, fd);
      uc->read = -1;
      uc->header = *buffer;
      return 1;
    }
    int pack_size = read_size(buffer);
    buffer += 2;
    size -= 2;

    if (size < pack_size) {
      /* Header complete, payload truncated: keep it for the next chunk. */
      struct uncomplete *uc = save_uncomplete(L, fd);
      uc->read = size;
      uc->pack.size = pack_size;
      uc->pack.buffer = skynet_malloc(pack_size);
      memcpy(uc->pack.buffer, buffer, size);
      return 1;
    }
    if (size == pack_size) {
      /* Exactly one package: return it directly without queueing. */
      lua_pushvalue(L, lua_upvalueindex(TYPE_DATA));
      lua_pushinteger(L, fd);
      void *result = skynet_malloc(pack_size);
      memcpy(result, buffer, size);
      lua_pushlightuserdata(L, result);
      lua_pushinteger(L, size);
      return 5;
    }
    /* Several packages (or one plus a fragment): queue them all. */
    push_data(L, fd, buffer, pack_size, 1);
    buffer += pack_size;
    size -= pack_size;
    push_more(L, fd, buffer, size);
    lua_pushvalue(L, lua_upvalueindex(TYPE_MORE));
    return 2;
  }
}

/*
 * Thin wrapper around filter_data_. The input buffer is deliberately not
 * released here: lfilter may pass a pointer into a Lua string, which Lua
 * owns.
 * NOTE(review): if callers pass heap-allocated socket buffers instead, they
 * must free them after this returns.
 */
static inline int filter_data(lua_State *L, int fd, uint8_t *buffer, int size) {
  return filter_data_(L, fd, buffer, size);
}

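/*
        Accepted message forms: a Lua string msg, or a lightuserdata pointer
        followed by an integer size.

        pack() returns a lightuserdata pointer and an integer size.
 */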

/*
 * Read a message argument at stack position index and store its byte count
 * in *sz. A userdata argument is taken as a raw pointer whose length is the
 * integer at index + 1. Anything else is read with luaL_checklstring, which
 * supplies its own length. The returned pointer is not copied.
 */
static const char *tolstring(lua_State *L, size_t *sz, int index) {
  if (!lua_isuserdata(L, index))
    return luaL_checklstring(L, index, sz);

  const char *raw = (const char *)lua_touserdata(L, index);
  *sz = (size_t)luaL_checkinteger(L, index + 1);
  return raw;
}

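/*
        userdata queue (may be nil; created on demand)
        string msg | lightuserdata msg, integer size
        integer fd (always argument 4)
        return one of
                userdata queue                  -- partial data buffered
                userdata queue, "data", integer fd, lightuserdata msg, integer size
                userdata queue, "more"          -- packages queued; fetch with pop
 */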
/*
 * filter(queue, msg, size, fd)
 * msg is a lightuserdata whose byte count is size, or a string (size is then
 * ignored). fd is always read from argument 4. The stack is trimmed to the
 * queue before filtering; results are produced by filter_data_.
 */
static int lfilter(lua_State *L) {
  size_t size;
  const char *buffer = tolstring(L, &size, 2);
  /* The reassembly code works with int sizes. A negative lightuserdata size
   * also arrives here as a huge size_t, so reject it instead of letting the
   * conversion to int silently truncate. */
  luaL_argcheck(L, size <= (size_t)INT_MAX, 2, "message size out of range");
  int fd = luaL_checkinteger(L, 4);
  lua_settop(L, 1);
  return filter_data(L, fd, (uint8_t *)buffer, (int)size);
}

/*
        userdata queue
        return
                integer fd
                lightuserdata msg
                integer size
 */
/*
 * Remove the oldest complete package from the queue at stack slot 1 and
 * return fd, buffer pointer and size; return nothing if the queue is
 * missing or empty. The queue no longer references the buffer afterwards.
 * NOTE(review): the caller presumably frees it (e.g. via tostring) — confirm.
 */
static int lpop(lua_State *L) {
  struct queue *q = lua_touserdata(L, 1);
  if (q == NULL)
    return 0;
  if (q->head == q->tail)
    return 0;

  struct netpack *np = &q->queue[q->head];
  q->head = (q->head + 1 >= q->cap) ? 0 : q->head + 1;

  lua_pushinteger(L, np->id);
  lua_pushlightuserdata(L, np->buffer);
  lua_pushinteger(L, np->size);
  return 3;
}

/* Store len as a 2-byte big-endian header; only the low 16 bits are kept. */
static inline void write_size(uint8_t *buffer, int len) {
  uint8_t hi = (uint8_t)((len >> 8) & 0xff);
  uint8_t lo = (uint8_t)(len & 0xff);
  buffer[0] = hi;
  buffer[1] = lo;
}

/*
 * pack(msg) -> lightuserdata, integer
 * Copy msg into a new heap buffer prefixed with its 2-byte big-endian
 * length, and return the buffer and its total length (len + 2). Nothing
 * frees the returned lightuserdata automatically; the caller must release it.
 * Raises an error if msg is longer than 0xFFFF bytes, the largest length a
 * uint16 header can encode.
 */
static int lpack(lua_State *L) {
  size_t len;
  const char *ptr = tolstring(L, &len, 1);
  /* The old bound (> 0x10000) let exactly 0x10000 bytes through, which
   * write_size would encode as a 0 header. */
  if (len > 0xFFFF) {
    return luaL_error(L, "Invalid size (too long) of data : %d", (int)len);
  }

  uint8_t *buffer = skynet_malloc(len + 2);
  if (buffer == NULL) {
    return luaL_error(L, "out of memory");
  }
  write_size(buffer, len);
  memcpy(buffer + 2, ptr, len);

  lua_pushlightuserdata(L, buffer);
  lua_pushinteger(L, len + 2);

  return 2;
}

/*
 * pack_string(msg) -> string
 * Return msg prefixed with its 2-byte big-endian length, as a Lua string.
 * Payloads up to SMALLSTRING bytes are assembled in a stack buffer; larger
 * ones in a temporary userdata left for Lua's garbage collector.
 * Raises an error if msg is longer than 0xFFFF bytes.
 */
static int lpack_string(lua_State *L) {
  uint8_t tmp[SMALLSTRING + 2];
  size_t len;
  uint8_t *buffer;
  const char *ptr = tolstring(L, &len, 1);
  /* A uint16 header encodes at most 0xFFFF; 0x10000 would wrap to 0. */
  if (len > 0xFFFF) {
    return luaL_error(L, "Invalid size (too long) of data : %d", (int)len);
  }

  if (len <= SMALLSTRING) {
    buffer = tmp;
  } else {
    buffer = lua_newuserdata(L, len + 2);
  }

  write_size(buffer, len);
  memcpy(buffer + 2, ptr, len);
  lua_pushlstring(L, (const char *)buffer, len + 2);

  return 1;
}

/*
 * pack_padding(cookie, msg [, size]) -> string
 * Build a length-prefixed Lua string whose body is msg followed by cookie.
 * The 2-byte big-endian header counts both parts. Small results are built
 * on the stack; larger ones in a temporary userdata.
 * Raises an error if the combined length exceeds 0xFFFF bytes.
 */
static int lpack_padding(lua_State *L) {
  uint8_t tmp[SMALLSTRING + 2];
  size_t content_sz;
  uint8_t *buffer;
  const char *ptr = tolstring(L, &content_sz, 2);
  size_t cookie_sz = 0;
  const char *cookie = luaL_checklstring(L, 1, &cookie_sz);
  size_t len = cookie_sz + content_sz;

  /* A uint16 header encodes at most 0xFFFF; 0x10000 would wrap to 0. */
  if (len > 0xFFFF) {
    return luaL_error(L, "Invalid size (too long) of data : %d", (int)len);
  }

  if (len <= SMALLSTRING) {
    buffer = tmp;
  } else {
    buffer = lua_newuserdata(L, len + 2);
  }

  write_size(buffer, len);
  memcpy(buffer + 2, ptr, content_sz);
  memcpy(buffer + 2 + content_sz, cookie, cookie_sz);
  lua_pushlstring(L, (const char *)buffer, len + 2);

  return 1;
}

/*
 * tostring(ptr, size [, offset]) -> string
 * Copy size bytes at ptr into a Lua string. A NULL/non-userdata ptr yields "".
 * Without offset the buffer is freed after copying. With a numeric offset
 * the bytes from offset (clamped to size) onward are copied and the buffer
 * is NOT freed, so the caller keeps ownership.
 * Raises an error for a negative offset, or for a negative size with a
 * non-NULL ptr (the buffer is then left untouched rather than freed).
 */
static int ltostring(lua_State *L) {
  void *ptr = lua_touserdata(L, 1);
  int size = luaL_checkinteger(L, 2);
  if (ptr == NULL) {
    lua_pushliteral(L, "");
    return 1;
  }
  /* A negative size would reach lua_pushlstring as a huge size_t. */
  luaL_argcheck(L, size >= 0, 2, "negative size");

  if (lua_isnumber(L, 3)) {
    int offset = lua_tointeger(L, 3);
    if (offset < 0) {
      return luaL_error(L, "Invalid offset %d", offset);
    }
    if (offset > size) {
      offset = size;
    }
    lua_pushlstring(L, (const char *)ptr + offset, size - offset);
  } else {
    lua_pushlstring(L, (const char *)ptr, size);
    skynet_free(ptr);
  }
  return 1;
}

/*
 * Module entry point: build the netpack table with pop/pack/pack_string/
 * pack_padding/clear/tostring, plus a filter closure whose upvalues are the
 * result-type names.
 */
int luaopen_netpackxa(lua_State *L) {
  luaL_checkversion(L);
  luaL_Reg funcs[] = {
      {"pop", lpop},
      {"pack", lpack},
      {"pack_string", lpack_string},
      {"pack_padding", lpack_padding},
      {"clear", lclear},
      {"tostring", ltostring},
      {NULL, NULL},
  };
  luaL_newlib(L, funcs);

  /* Upvalue i of filter is the name for TYPE_* == i, so this list must keep
   * the same order as those macros. */
  static const char *const type_names[] = {"data", "more", "error", "open",
                                           "close"};
  const int ntypes = (int)(sizeof type_names / sizeof type_names[0]);
  int i;
  for (i = 0; i < ntypes; i++)
    lua_pushstring(L, type_names[i]);

  lua_pushcclosure(L, lfilter, ntypes);
  lua_setfield(L, -2, "filter");
  return 1;
}
